#' README
#' Run this script once for every combination of
#'   mix in c(0, 1/3), and
#'   gt_end_months in c("2014-10", "2015-04", "2015-10", "2016-04", "2016-10", "2017-04", "2017-10", "2018-04", "2018-10", "2019-04", "2019-10", "2020-03"),
#' where gt.folder = paste0("pub_history_ending_", gt_end_months, "/").
#' Both settings are passed on the command line: Rscript <this script> <mix> <gt.folder>
#' After all settings have been run, summary.R generates the summary.
#' run.R runs all settings in a loop.


options(echo=TRUE)
library(xts)
library(glmnet)
library(argo)
library(parallel)
library(data.table)

# Number of cores handed to argo() via mc.cores.
NCORES <- 8

# Command-line usage: Rscript <this script> [mix] [gt.folder]
#   mix       weight given to the regional Google Trends series when it is
#             blended with each state's own series:
#             (1 - mix) * state + mix * region; 0 means state-only.
#   gt.folder Google Trends directory, e.g. "pub_history_ending_2020-03/".
# Each argument that is not supplied falls back to the default below
# (previously a single argument left gt.folder as NA).
args <- commandArgs(trailingOnly = TRUE)
mix <- if (length(args) >= 1) as.numeric(args[1]) else 1/3  # FIXME change this to 0 or 1/3
gt.folder <- if (length(args) >= 2) args[2] else "pub_history_ending_2020-03/"  # FIXME change this to Google Trends directory

# Fail fast instead of silently propagating NA through every model fit.
if (length(mix) != 1 || is.na(mix) || mix < 0 || mix > 1) {
  stop("`mix` must be a single number in [0, 1], got: ", args[1], call. = FALSE)
}


# Input locations (edit for your setup).
ili.folder <- "ili20200327/"  # FIXME change this to ili directory
population.file <- "Population.csv"  # FIXME change this to population file

reg_data <- load_reg_data(
  gt.folder = gt.folder,
  ili.folder = ili.folder,
  population.file = population.file,
  gft.file = "GFT.txt",  # FIXME change this to GFT file
  gt.parser = argo:::gt.parser.pub.api
)

# The GT data keys New York City as "501" and the ILI data as
# "US.New York City"; give both the common key "US.NY.NYC".
names(reg_data$GT_state)[names(reg_data$GT_state) %in% "501"] <- "US.NY.NYC"
names(reg_data$ili_state)[names(reg_data$ili_state) %in% "US.New York City"] <- "US.NY.NYC"

# Print the distribution of (rows, cols) across the per-state GT series.
summary(t(sapply(reg_data$GT_state, dim)))
summary(t(sapply(reg_data$GT_state_filled, dim)))

# Unpack the loaded series into top-level objects used by the rest of the script.
ili_national <- reg_data$ili_national
ili_regional <- reg_data$ili_regional
ili_state <- reg_data$ili_state
GT_national <- reg_data$GT_national
# CAVEAT: state-level raw data may be of different scale when aggregated to regional
GT_regional <- reg_data$GT_regional
GT_state <- reg_data$GT_state

# All GT series except the national aggregate "US" and "US.FL".
state_names <- setdiff(names(reg_data$GT_state), c("US", "US.FL"))
state_info <- fread(population.file)

# Blend each state's GT series with its region's series:
#   (1 - mix) * state + mix * region
# The region id is the Region column of the population file, looked up by the
# second "."-separated component of the state name (e.g. "NY" in "US.NY.NYC").
GT_regstate_mixed <- list()
for (each_state in state_names) {
  region_id_for_state <- state_info[Abbre == unlist(strsplit(each_state, ".", fixed = TRUE))[2], Region]
  GT_regstate_mixed[[each_state]] <-
    reg_data$GT_state[[each_state]] * (1 - mix) + GT_regional[[region_id_for_state]] * mix
}


# Quick visual sanity check of one state's ILI series (US.MT).
# The original drew this identical series a second time on top of itself
# via par(new = TRUE), which added nothing. NOTE(review): that was presumably
# meant to overlay a different series; confirm.
# Plot only in interactive sessions, so that batch runs (e.g. the run.R loop)
# do not open the default non-interactive device and write a stray Rplots.pdf.
if (interactive()) {
  plot.zoo(ili_state$US.MT)
}


#### first-step argo prediction ####
# Move an ILI percentage onto the logit scale used for the ARGO fits.
# The +0.1 shift keeps a 0% week away from logit(0) = -Inf; dividing by 100
# turns the percentage into a proportion.
transY <- function(y) {
  proportion <- (y + 0.1) / 100
  logit(proportion)
}

# Inverse of transY(): map a logit-scale value back to an ILI percentage.
inv_transY <- function(y) {
  shifted_pct <- 100 * logit_inv(y)
  shifted_pct - 0.1
}


#' Fit the first-step ARGO models at the national, regional and state level.
#'
#' Reads these objects from the enclosing (global) environment rather than
#' taking them as arguments: GT_national, ili_national, ili_regional,
#' GT_regional, ili_state, GT_regstate_mixed, state_names, NCORES, transY,
#' inv_transY.
#'
#' @param terms character vector of Google Trends column names used as
#'   exogenous predictors (log(x + 1)) at every level.
#' @param period_nat dates used for the national fit (extended 52 weekly steps
#'   back) and for the regional fits.
#' @param period_reg dates used for the state-level fits.
#' @return list with argo_national, argo_regional, pred_regional, argo_state
#'   and pred_state. Each pred_* entry is an xts with columns "CDC.data"
#'   (observed ILI), the back-transformed ARGO prediction (the column read as
#'   "predict" downstream) and "naive" (the previous row's observed ILI).
get_argo1 <- function(terms, period_nat, period_reg){
  # Merge ARGO's back-transformed prediction with the observed series and add
  # a naive column (previous row's observed value). Shared by the regional and
  # state loops, which previously duplicated this code.
  blend_with_truth <- function(truth_col, argo_fit) {
    blended <- merge(truth_col, inv_transY(argo_fit$pred), all = FALSE)
    # Column 1 is the observed series (merge keeps argument order).
    # head(, -1) drops the last value and also works for a single row, where
    # the old 1:(nrow - 1) indexing produced 1:0.
    blended$naive <- c(NA, head(as.numeric(blended[, 1]), -1))
    names(blended)[1] <- "CDC.data"
    blended
  }

  common_idx <- period_nat
  set.seed(1000)

  # national
  # Outer-merge with ili_national (local copy only) so GT rows exist for every
  # ILI week, then drop the ILI column again.
  # NOTE(review): assumes the merged ILI column is named "ili_national".
  GT_national <- merge(GT_national, ili_national)
  GT_national$ili_national <- NULL
  # Prepend 52 earlier weekly dates to the national fitting period.
  common_idx_nat_append <- c(common_idx[1] - (52:1)*7, common_idx)
  argo_national <- argo(data = transY(ili_national[common_idx_nat_append]),
                        exogen = log(GT_national[common_idx_nat_append, terms]+1),
                        mc.cores = NCORES)

  # regional
  # NOTE(review): assumes GT_regional[[k]] corresponds to column "Region.k" of
  # ili_regional; confirm.
  argo_regional <- list()
  pred_regional <- list()
  for (region.id in 1:10) {
    j <- paste0("Region.", region.id)
    set.seed(1000)
    argo_regional[[j]] <- argo(transY(ili_regional[common_idx, j]),
                               log(GT_regional[[region.id]][common_idx, terms]+1),
                               mc.cores = NCORES,
                               N_lag = NULL)
    pred_regional[[j]] <- blend_with_truth(ili_regional[, j], argo_regional[[j]])
    print(j)
  }

  # state-level, on the state/region-mixed GT series
  common_idx <- period_reg
  argo_state <- list()
  pred_state <- list()
  for (j in state_names) {  # iterates nothing (not c(1, 0)) when state_names is empty
    set.seed(1000)
    argo_state[[j]] <- argo(transY(ili_state[common_idx, j]),
                            log(GT_regstate_mixed[[j]][common_idx, terms]+1),
                            mc.cores = NCORES,
                            N_lag = NULL)
    # TODO confirm in the final prediction we didn't use the AR3 lag in the state level ARGO
    pred_state[[j]] <- blend_with_truth(ili_state[, j], argo_state[[j]])
    print(j)
  }

  list(argo_national=argo_national,
       argo_regional=argo_regional,
       pred_regional=pred_regional,
       argo_state=argo_state,
       pred_state=pred_state)
}

# Fitting periods:
#   national/regional: weeks that have both national ILI and national GT
#   state:             weeks that have both state ILI and national GT
common_idx_nat <- index(merge(ili_national, GT_national, all = FALSE))
# common_idx_nat <- common_idx_nat[common_idx_nat >= "2009-10-10"]
common_idx_reg <- index(merge(ili_state, GT_national, all = FALSE))
terms <- colnames(GT_national)

# First-step ARGO (a single call covering both periods above).
argo1.post10 <- get_argo1(terms, common_idx_nat, common_idx_reg)

#### collect first-step outputs ####
pred_state <- argo1.post10$pred_state

argo1.coef <- list()
for (j in state_names) {
  argo1.coef[[j]] <- argo1.post10$argo_state[[j]]$coef
}

argo.nat.p <- inv_transY(argo1.post10$argo_national$pred)
argo.nat.coef <- argo1.post10$argo_national$coef

#### argo second step ####
# One column of first-step predictions per region / per state, named after it.
argo.reg.p <- do.call(merge, lapply(argo1.post10$pred_regional, function(x) x[, "predict"]))
colnames(argo.reg.p) <- names(argo1.post10$pred_regional)

argo.state.p <- do.call(merge, lapply(pred_state, function(x) x[, "predict"]))
colnames(argo.state.p) <- names(pred_state)

#' Second-step ARGO: combine the state, regional and national first-step
#' predictions (and optionally last week's increment) into a state-level
#' nowcast of the weekly ILI increment.
#'
#' The target is Y = truth - naive, where naive is truth shifted forward by
#' 7 days. For each week after the first `n_train` usable weeks, Y is
#' predicted linearly from:
#'   X      state first-step increment,
#'   X.reg  the state's region's first-step increment,
#'   X.nat  national first-step increment, and
#'   Yt2    Y one week earlier (only when use_yt2).
#'
#' The predictor covariance is structured around sigma_yy = var(Y) over the
#' trailing window: each predictor acts as Y plus its own prediction error.
#' That covariance is shrunk by adding the diagonal of the empirical predictor
#' covariance.
#'
#' @param truth xts of observed state ILI, one column per state_names entry.
#' @param argo.state.p xts of first-step state predictions, same columns.
#' @param argo.reg.p xts of first-step regional predictions, one column per
#'   region, indexed by the Region column of state_info.
#' @param argo.nat.p xts of first-step national predictions. Presumably a
#'   single column, since it is recycled across states via as.numeric().
#' @param state_names state keys; the second "."-separated component is
#'   matched against state_info$Abbre.
#' @param state_info data.table with at least columns Abbre and Region.
#' @param use_yt2 include the lagged increment Yt2 as an extra predictor.
#' @param n_train length of the trailing training window in weeks. The
#'   default of 104 matches the previously hard-coded value.
#' @return list with the one-step, two-step, naive and true series, the
#'   predicted increments and errors, and per-week diagnostic matrices.
argo2 <- function(truth, argo.state.p, argo.reg.p, argo.nat.p, state_names, state_info, use_yt2=TRUE, n_train = 104){
  stopifnot(n_train >= 2)  # var()/cor() need at least two weeks

  # naive prediction for week t is the observed value at t - 7 days
  naive.p <- truth
  index(naive.p) <- index(truth) + 7
  common_idx <- index(na.omit(merge(truth, naive.p, argo.state.p, argo.nat.p)))

  Y <- truth - naive.p
  # Yt2 at week t is the increment observed one week earlier
  Yt2 <- Y
  index(Yt2) <- index(Yt2) + 7

  argo.nat.p <- argo.nat.p[common_idx]
  argo.reg.p <- argo.reg.p[common_idx]
  naive.p <- naive.p[common_idx]
  truth <- truth[common_idx]
  argo.state.p <- argo.state.p[common_idx]

  X <- argo.state.p - naive.p

  # as.numeric() drops the index so the national series is recycled across
  # every state column
  X.nat <- as.numeric(argo.nat.p) - naive.p
  X.nat <- X.nat[common_idx]

  # one regional column per state, ordered like state_names
  argo.reg.p.dup <- lapply(state_names, function(each_state){
    region_id_for_state = state_info[Abbre==strsplit(each_state, "\\.")[[1]][2], Region]
    argo.reg.p[,region_id_for_state]
  })
  argo.reg.p.dup <- do.call(merge, argo.reg.p.dup)
  names(argo.reg.p.dup) <- state_names
  X.reg <- argo.reg.p.dup - naive.p

  # Loop-invariant series, built once instead of once per week.
  # W.all / ZW.all include Yt2 only when it is used as a predictor, so every
  # W-shaped matrix below matches vcov.x_xreg_xnat. Previously, use_yt2 = FALSE
  # always failed on a 3p x 3p + 4p x 4p addition.
  if (use_yt2) {
    W.all <- cbind(X, X.reg, X.nat, Yt2)
    ZW.all <- cbind(Y, X, X.reg, X.nat, Yt2)
  } else {
    W.all <- cbind(X, X.reg, X.nat)
    ZW.all <- cbind(Y, X, X.reg, X.nat)
  }
  err.state <- argo.state.p - truth
  err.reg <- argo.reg.p.dup - truth
  err.nat <- as.numeric(argo.nat.p) - truth

  projection.mat <- list()
  mean.mat <- list()

  Y.pred <- X
  Y.pred[] <- NA

  zw_used <- list()
  sigma_ww.structured <- sigma_ww.empirical <-
    sigma_zw.structured <- sigma_zw.empirical <-
    sigma_zwzw.structured <- sigma_zwzw.empirical <- list()

  # Weeks n_train+1 .. end. This sequence is empty, rather than counting
  # backwards, when there are too few weeks.
  for (it in n_train + seq_len(max(0, length(common_idx) - n_train))) {
    training_idx <- common_idx[(it - n_train):(it - 1)]
    t.now <- common_idx[it]
    key <- as.character(t.now)
    y <- Y[training_idx,]
    x <- X[training_idx,]
    x.nat <- X.nat[training_idx,]
    x.reg <- X.reg[training_idx,]
    yt2 <- Yt2[training_idx,]

    sigma_yy <- var(y)

    # lag-one autocorrelation of the increments, pooled over states
    m1 <- cor(y, yt2)
    m2 <- cor(y)
    rho.l2 <- sum(m1*m2)/sum(m2^2)

    autocov.y.yt2 <- rho.l2*sigma_yy

    # vcov.x_xreg_xnat <-
    #   cbind(rbind(sigma_yy+diag(diag(var((argo.state.p - truth)[training_idx,]))),sigma_yy),
    #         rbind(sigma_yy,sigma_yy+var((as.numeric(argo.nat.p) - truth)[training_idx,])))
    # structured covariance: off-diagonal blocks sigma_yy, diagonal blocks
    # sigma_yy + var(first-step error)
    vcov.x_xreg_xnat <-
      cbind(rbind(sigma_yy+var(err.state[training_idx,]),sigma_yy, sigma_yy),
            rbind(sigma_yy,sigma_yy+var(err.reg[training_idx,]), sigma_yy),
            rbind(sigma_yy,sigma_yy,sigma_yy+var(err.nat[training_idx,])))
    sigma_zw <- cbind(sigma_yy,sigma_yy,sigma_yy)

    # centred predictor vector for the target week
    w.now <- c(t(X[t.now,])-colMeans(x), t(X.reg[t.now,])-colMeans(x.reg), t(X.nat[t.now,])-colMeans(x.nat))

    if(use_yt2){
      vcov.x_xreg_xnat <- cbind(vcov.x_xreg_xnat, rbind(autocov.y.yt2,autocov.y.yt2,autocov.y.yt2))
      vcov.x_xreg_xnat <- rbind(vcov.x_xreg_xnat, cbind(t(autocov.y.yt2),t(autocov.y.yt2),t(autocov.y.yt2),sigma_yy))
      sigma_zw <- cbind(sigma_zw, autocov.y.yt2)
      w.now <- c(w.now, t(Yt2[t.now,]-colMeans(yt2)))
    }

    # not shrunk: best linear predictor under the structured covariance
    y.pred.blp <- colMeans(y) + sigma_zw %*% solve(vcov.x_xreg_xnat, w.now)

    Kzz <- solve((1-rho.l2^2)*sigma_yy)
    Kgt <- solve(var(err.state[training_idx,]))
    Kreg <- solve(var(err.reg[training_idx,]))
    Knat <- solve(var(err.nat[training_idx,]))

    # equivalent precision-weighted (Bayesian) form, used as a consistency check
    if(use_yt2){
      y.pred.bayes <- colMeans(y) +
        solve(Kzz+Knat+Kreg+Kgt,
              Knat%*%(t(X.nat[t.now,])-colMeans(x.nat)) + Kgt%*%(t(X[t.now,])-colMeans(x)) +
                Kreg%*%(t(X.reg[t.now,])-colMeans(x.reg)) +
                Kzz%*%(rho.l2*(t(Yt2[t.now,]-colMeans(yt2)))))
    }else{
      y.pred.bayes <- colMeans(y) +
        solve(solve(sigma_yy) + Knat+Kreg+Kgt,  # prior of y is mean 0 vcov sigma_yy
              Knat%*%(t(X.nat[t.now,])-colMeans(x.nat)) + Kgt%*%(t(X[t.now,])-colMeans(x)) +
                Kreg%*%(t(X.reg[t.now,])-colMeans(x.reg)))
    }

    if(all(is.finite(y.pred.blp))){
      stopifnot(all(abs(y.pred.blp-y.pred.bayes) < 1e-8))
    }

    # shrunk
    # Alternatives tried with use_yt2 = TRUE:
    # if use empirical 50x50 matrix with shrinkage - result is 17% worse
    # vcov.x_xreg_xnat <- var(cbind(X,X.reg,X.nat,Yt2)[training_idx,])
    # sigma_zw <- cov(Y[training_idx,], cbind(X,X.reg,X.nat,Yt2)[training_idx,])

    # if only state prediction use empirical, result is very very bad (explode)
    # state_index <- array(FALSE, dim = dim(vcov.x_xreg_xnat))
    # state_index[1:51,] <- TRUE
    # state_index[,1:51] <- TRUE
    # vcov.x_xreg_xnat[state_index] <- var(cbind(X,X.reg,X.nat,Yt2)[training_idx,])[state_index]
    # sigma_zw[, 1:51] <- cov(Y[training_idx,], cbind(X,X.reg,X.nat,Yt2)[training_idx,])[, 1:51]

    # if state prediction use structured, result is very bad (half-explode)
    # vcov.x_xreg_xnat[1:51, 1:51] <- var(X)
    # sigma_zw[, 1:51] <- cov(Y[training_idx,], X[training_idx,])
    # vcov.x_xreg_xnat[1:51, 52:102] <- t(sigma_zw[, 1:51])
    # vcov.x_xreg_xnat[1:51, 103:153] <- t(sigma_zw[, 1:51])
    # vcov.x_xreg_xnat[52:102, 1:51] <- sigma_zw[, 1:51]
    # vcov.x_xreg_xnat[103:153, 1:51] <- sigma_zw[, 1:51]
    # vcov.x_xreg_xnat[1:51, 154:204] <- cov(X[training_idx,], Yt2[training_idx,])
    # vcov.x_xreg_xnat[154:204, 1:51] <- cov(Yt2[training_idx,], X[training_idx,])
    # nsize=44
    # sum(vcov.x_xreg_xnat[(2*nsize+1):(3*nsize), (3*nsize+1):(4*nsize)])
    sigma_ww.emp <- var(W.all[training_idx,])
    vcov.shrunk <- vcov.x_xreg_xnat + diag(diag(sigma_ww.emp))
    y.pred <- colMeans(y) + sigma_zw %*% solve(vcov.shrunk, w.now)

    Y.pred[t.now, ] <- t(y.pred)
    projection.mat[[key]] <- sigma_zw %*% solve(vcov.shrunk)
    mean.mat[[key]] <- c(colMeans(y), colMeans(x), colMeans(x.reg), colMeans(x.nat),
                         if (use_yt2) colMeans(yt2))

    sigma_ww <- vcov.x_xreg_xnat
    sigma_zz <- sigma_yy
    sigma_ww.structured[[key]] <- sigma_ww
    sigma_ww.empirical[[key]] <- sigma_ww.emp
    sigma_zw.structured[[key]] <- sigma_zw
    sigma_zw.empirical[[key]] <- cov(Y[training_idx,], W.all[training_idx,])

    sigma_zwzw.structured[[key]] <- rbind(
      cbind(sigma_zz, sigma_zw),
      cbind(t(sigma_zw), sigma_ww)
    )
    sigma_zwzw.empirical[[key]] <- var(ZW.all[training_idx,])
    zw_used[[key]] <- ZW.all[training_idx,]
  }

  projection.mat <- sapply(projection.mat, identity, simplify = "array")
  mean.mat <- sapply(mean.mat, identity, simplify = "array")

  argo2.p <- Y.pred + naive.p

  err.twostep <- argo2.p - truth

  heat.vec <- na.omit(merge(truth-naive.p, argo.state.p - truth, argo.reg.p.dup - truth, as.numeric(argo.nat.p) - truth, Yt2))
  colnames(heat.vec) <- paste0(rep(c("CDC.increment.", "err.argo.", "err.reg.", "err.nat.", "err.y2."), each=length(state_names)),
                               state_names)

  sigma_ww.structured <- sapply(sigma_ww.structured, identity, simplify = "array")
  sigma_ww.empirical <- sapply(sigma_ww.empirical, identity, simplify = "array")
  sigma_zw.structured <- sapply(sigma_zw.structured, identity, simplify = "array")
  sigma_zw.empirical <- sapply(sigma_zw.empirical, identity, simplify = "array")
  sigma_zwzw.structured <- sapply(sigma_zwzw.structured, identity, simplify = "array")
  sigma_zwzw.empirical <- sapply(sigma_zwzw.empirical, identity, simplify = "array")
  zw_used <- sapply(zw_used, identity, simplify = "array")

  list(onestep=argo.state.p, twostep=argo2.p, naive=naive.p, truth=truth,
       Y.pred=Y.pred, err.twostep=err.twostep,
       heat.vec=heat.vec, projection.mat=projection.mat, mean.mat=mean.mat,
       sigma_ww.structured=sigma_ww.structured, sigma_ww.empirical=sigma_ww.empirical,
       sigma_zw.structured=sigma_zw.structured, sigma_zw.empirical=sigma_zw.empirical,
       sigma_zwzw.structured=sigma_zwzw.structured, sigma_zwzw.empirical=sigma_zwzw.empirical,
       zw_used=zw_used, zw_overall = ZW.all)
}

#' Per-state baseline for argo2(): fit each state independently.
#'
#' For each state and each week after the first `n_train` usable weeks, the
#' ILI increment Y = truth - naive (naive = truth shifted forward by 7 days)
#' is predicted from that state's own predictors:
#'   - its first-step increment,
#'   - the national increment,
#'   - last week's increment, and
#'   - optionally the regional increment.
#' The empirical covariance over the trailing window is used, shrunk by adding
#' its own diagonal.
#'
#' @param truth xts of observed state ILI, one column per state_names entry.
#' @param argo.state.p xts of first-step state predictions, same columns.
#' @param argo.reg.p xts of first-step regional predictions, one column per
#'   region, indexed by the Region column of state_info.
#' @param argo.nat.p xts of first-step national predictions. Presumably a
#'   single column, since it is recycled across states via as.numeric().
#' @param state_names state keys; the second "."-separated component is
#'   matched against state_info$Abbre.
#' @param state_info data.table with at least columns Abbre and Region.
#' @param use_reg include the regional increment as a predictor.
#' @param n_train trailing training window in weeks; the default of 104
#'   matches argo2() and the previously hard-coded value.
#' @return list with onestep, twostep, naive, truth, Y.pred and err.twostep.
argo_ind <- function(truth, argo.state.p, argo.reg.p, argo.nat.p, state_names, state_info, use_reg=FALSE, n_train = 104){
  stopifnot(n_train >= 2)  # cov()/var() need at least two weeks

  naive.p <- truth
  index(naive.p) <- index(truth) + 7
  common_idx <- index(na.omit(merge(truth, naive.p, argo.state.p, argo.nat.p)))

  Y <- truth - naive.p
  # Yt2 at week t is the increment observed one week earlier
  Yt2 <- Y
  index(Yt2) <- index(Yt2) + 7

  argo.nat.p <- argo.nat.p[common_idx]
  argo.reg.p <- argo.reg.p[common_idx]
  naive.p <- naive.p[common_idx]
  truth <- truth[common_idx]
  argo.state.p <- argo.state.p[common_idx]

  X <- argo.state.p - naive.p

  # as.numeric() drops the index so the national series is recycled across states
  X.nat <- as.numeric(argo.nat.p) - naive.p
  X.nat <- X.nat[common_idx]

  argo.reg.p.dup <- lapply(state_names, function(each_state){
    region_id_for_state = state_info[Abbre==strsplit(each_state, "\\.")[[1]][2], Region]
    argo.reg.p[,region_id_for_state]
  })
  argo.reg.p.dup <- do.call(merge, argo.reg.p.dup)
  names(argo.reg.p.dup) <- state_names
  X.reg <- argo.reg.p.dup - naive.p

  Y.pred <- X
  Y.pred[] <- NA

  # Weeks n_train+1 .. end; the sequence is empty (not backwards) when too short.
  for (it in n_train + seq_len(max(0, length(common_idx) - n_train))) {
    # training window and target week do not depend on the state
    training_idx <- common_idx[(it - n_train):(it - 1)]
    t.now <- common_idx[it]
    for (state_id in seq_len(ncol(argo.state.p))) {
      y <- Y[training_idx,state_id]
      x <- X[training_idx,state_id]
      x.nat <- X.nat[training_idx,state_id]
      x.reg <- X.reg[training_idx,state_id]
      yt2 <- Yt2[training_idx,state_id]

      if(use_reg){
        sigma_zw <- cov(y, cbind(x, x.reg, x.nat, yt2))
        vcov.x_xreg_xnats <- var(cbind(x, x.reg, x.nat, yt2))

        y.pred <- colMeans(y) +
          sigma_zw %*%
          solve(vcov.x_xreg_xnats + diag(diag(vcov.x_xreg_xnats)),
                c(X[t.now,state_id]-mean(x), X.reg[t.now,state_id]-mean(x.reg), X.nat[t.now,state_id]-mean(x.nat), Yt2[t.now,state_id]-mean(yt2)))
      }else{
        sigma_zw <- cov(y, cbind(x, x.nat, yt2))
        vcov.x_xreg_xnats <- var(cbind(x, x.nat, yt2))

        y.pred <- colMeans(y) +
          sigma_zw %*%
          solve(vcov.x_xreg_xnats + diag(diag(vcov.x_xreg_xnats)),
                c(X[t.now,state_id]-mean(x), X.nat[t.now,state_id]-mean(x.nat), Yt2[t.now,state_id]-mean(yt2)))
      }
      Y.pred[t.now, state_id] <- y.pred
    }
  }

  argo2.p <- Y.pred + naive.p
  err.twostep <- argo2.p - truth

  list(onestep=argo.state.p, twostep=argo2.p, naive=naive.p, truth=truth,
       Y.pred=Y.pred, err.twostep=err.twostep)
}

# Second step on all states; regional/national first-step predictions are
# aligned to the weeks of the state predictions.
argo2_result <- argo2(
  ili_state[, state_names],
  argo.state.p,
  argo.reg.p[index(argo.state.p)],
  argo.nat.p[index(argo.state.p)],
  state_names, state_info, TRUE
)

# Same on a subset of states. NOTE(review): why these seven are excluded is not
# stated here; confirm.
state_names_sub <- setdiff(
  state_names,
  c("US.HI", "US.AK", "US.ME", "US.MT", "US.ND", "US.SD", "US.VT")
)
argo2_result_sub <- argo2(
  ili_state[, state_names_sub],
  argo.state.p[, state_names_sub],
  argo.reg.p[index(argo.state.p)],
  argo.nat.p[index(argo.state.p)],
  state_names_sub, state_info, TRUE
)

# Independent per-state baseline, without the regional predictor.
argo_ind_result <- argo_ind(
  ili_state[, state_names],
  argo.state.p,
  argo.reg.p[index(argo.state.p)],
  argo.nat.p[index(argo.state.p)],
  state_names, state_info,
  use_reg = FALSE
)

# Save the whole workspace to argo2_mix<mix>/argo2-state-GT-<last gt.folder component>.rda
outDir <- paste0("argo2_mix", round(mix, 2))
dir.create(outDir, showWarnings = FALSE, recursive = TRUE)
save.image(file.path(outDir, paste0("argo2-state-GT-", tail(strsplit(gt.folder, "/")[[1]], 1), ".rda")))
